{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.data import Sentence\n",
    "from flair.models import SequenceTagger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('I love Berlin .')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tagger = SequenceTagger.load('ner')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"I love Berlin .\" - 4 Tokens]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tagger.predict(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: \"I love Berlin .\" - 4 Tokens\n"
     ]
    }
   ],
   "source": [
    "print(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LOC-span [3]: \"Berlin\"\n"
     ]
    }
   ],
   "source": [
    "for entity in sentence.get_spans('ner'):\n",
    "    print(entity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: \"The grass is green .\" - 5 Tokens\n"
     ]
    }
   ],
   "source": [
    "sentence = Sentence('The grass is green .')\n",
    "print(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: 4 green\n"
     ]
    }
   ],
   "source": [
    "print(sentence.get_token(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: 4 green\n"
     ]
    }
   ],
   "source": [
    "print(sentence[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: 1 The\n",
      "Token: 2 grass\n",
      "Token: 3 is\n",
      "Token: 4 green\n",
      "Token: 5 .\n"
     ]
    }
   ],
   "source": [
    "for token in sentence:\n",
    "    print(token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: \"The grass is green .\" - 5 Tokens\n"
     ]
    }
   ],
   "source": [
    "# Let flair tokenize the raw text so the trailing '.' becomes its own token.\n",
    "sentence = Sentence('The grass is green.', use_tokenizer=True)\n",
    "print(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence[3].add_tag('ner', 'color')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The grass is green <color> .\n"
     ]
    }
   ],
   "source": [
    "print(sentence.to_tagged_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.data import Label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "tag: Label = sentence[3].get_tag('ner')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\"Token: 4 green\" is tagged as \"color\" with confidence score \"1.0\"\n"
     ]
    }
   ],
   "source": [
    "print(f'\"{sentence[3]}\" is tagged as \"{tag.value}\" with confidence score \"{tag.score}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('France is the current World Cup winner.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence.add_label('sports')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'sports' was already attached by add_label() in the previous cell, so only\n",
    "# the new label is added here; passing 'sports' again would duplicate it.\n",
    "sentence.add_labels(['world cup'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: \"France is the current World Cup winner.\" - 7 Tokens\n"
     ]
    }
   ],
   "source": [
    "print(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sports (1.0)\n",
      "world cup (1.0)\n"
     ]
    }
   ],
   "source": [
    "# Each label prints as its value followed by its confidence score.\n",
    "for label in sentence.labels:\n",
    "    print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: \"France is the current World Cup winner\" - 7 Tokens\n",
      "sports (1.0)\n",
      "world cup (1.0)\n"
     ]
    }
   ],
   "source": [
    "sentence = Sentence('France is the current World Cup winner', labels=['sports', 'world cup'])\n",
    "print(sentence)\n",
    "for label in sentence.labels:\n",
    "    print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.models import SequenceTagger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "tagger = SequenceTagger.load('ner')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('George Washington went to Washington .')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"George Washington went to Washington .\" - 6 Tokens]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tagger.predict(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "George <B-PER> Washington <E-PER> went to Washington <S-LOC> .\n"
     ]
    }
   ],
   "source": [
    "print(sentence.to_tagged_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PER-span [1,2]: \"George Washington\"\n",
      "LOC-span [5]: \"Washington\"\n"
     ]
    }
   ],
   "source": [
    "for entity in sentence.get_spans('ner'):\n",
    "    print(entity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'text': 'George Washington went to Washington .', 'labels': [], 'entities': [{'text': 'George Washington', 'start_pos': 0, 'end_pos': 17, 'type': 'PER', 'confidence': 0.999337375164032}, {'text': 'Washington', 'start_pos': 26, 'end_pos': 36, 'type': 'LOC', 'confidence': 0.9998500347137451}]}\n"
     ]
    }
   ],
   "source": [
    "print(sentence.to_dict(tag_type='ner'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "tagger = SequenceTagger.load('frame')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence_1 = Sentence('George returned to Berlin to return his hat .')\n",
    "sentence_2 = Sentence('He had a look at different hats .')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"George returned to Berlin to return his hat .\" - 9 Tokens]"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tagger.predict(sentence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"He had a look at different hats .\" - 8 Tokens]"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tagger.predict(sentence_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "George returned <return.01> to Berlin to return <return.02> his hat .\n",
      "He had <have.LV> a look <look.01> at different hats .\n"
     ]
    }
   ],
   "source": [
    "print(sentence_1.to_tagged_string())\n",
    "print(sentence_2.to_tagged_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = 'This is a sentence. This is another sentence. I love Berlin.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "from segtok.segmenter import split_single"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = [Sentence(sent, use_tokenizer=True) for sent in split_single(text)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"This is a sentence .\" - 5 Tokens,\n",
       " Sentence: \"This is another sentence .\" - 5 Tokens,\n",
       " Sentence: \"I love Berlin .\" - 4 Tokens]"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"This is a sentence .\" - 5 Tokens,\n",
       " Sentence: \"This is another sentence .\" - 5 Tokens,\n",
       " Sentence: \"I love Berlin .\" - 4 Tokens]"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tagger: SequenceTagger = SequenceTagger.load('ner')\n",
    "tagger.predict(sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.embeddings import WordEmbeddings\n",
    "\n",
    "# init embedding\n",
    "glove_embedding = WordEmbeddings('glove')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('The grass is green .')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"The grass is green .\" - 5 Tokens]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove_embedding.embed(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: 1 The\n",
      "tensor([-0.0382, -0.2449,  0.7281, -0.3996,  0.0832,  0.0440, -0.3914,  0.3344,\n",
      "        -0.5755,  0.0875,  0.2879, -0.0673,  0.3091, -0.2638, -0.1323, -0.2076,\n",
      "         0.3340, -0.3385, -0.3174, -0.4834,  0.1464, -0.3730,  0.3458,  0.0520,\n",
      "         0.4495, -0.4697,  0.0263, -0.5415, -0.1552, -0.1411, -0.0397,  0.2828,\n",
      "         0.1439,  0.2346, -0.3102,  0.0862,  0.2040,  0.5262,  0.1716, -0.0824,\n",
      "        -0.7179, -0.4153,  0.2033, -0.1276,  0.4137,  0.5519,  0.5791, -0.3348,\n",
      "        -0.3656, -0.5486, -0.0629,  0.2658,  0.3020,  0.9977, -0.8048, -3.0243,\n",
      "         0.0125, -0.3694,  2.2167,  0.7220, -0.2498,  0.9214,  0.0345,  0.4674,\n",
      "         1.1079, -0.1936, -0.0746,  0.2335, -0.0521, -0.2204,  0.0572, -0.1581,\n",
      "        -0.3080, -0.4162,  0.3797,  0.1501, -0.5321, -0.2055, -1.2526,  0.0716,\n",
      "         0.7056,  0.4974, -0.4206,  0.2615, -1.5380, -0.3022, -0.0734, -0.2831,\n",
      "         0.3710, -0.2522,  0.0162, -0.0171, -0.3898,  0.8742, -0.7257, -0.5106,\n",
      "        -0.5203, -0.1459,  0.8278,  0.2706])\n",
      "Token: 2 grass\n",
      "tensor([-0.8135,  0.9404, -0.2405, -0.1350,  0.0557,  0.3363,  0.0802, -0.1015,\n",
      "        -0.5478, -0.3537,  0.0734,  0.2587,  0.1987, -0.1433,  0.2507,  0.4281,\n",
      "         0.1950,  0.5346,  0.7424,  0.0578, -0.3178,  0.9436,  0.8145, -0.0824,\n",
      "         0.6166,  0.7284, -0.3262, -1.3641,  0.1232,  0.5373, -0.5123,  0.0246,\n",
      "         1.0822, -0.2296,  0.6039,  0.5541, -0.9610,  0.4803,  0.0022,  0.5591,\n",
      "        -0.1637, -0.8468,  0.0741, -0.6216,  0.0260, -0.5162, -0.0525, -0.1418,\n",
      "        -0.0161, -0.4972, -0.5534, -0.4037,  0.5096,  1.0276, -0.0840, -1.1179,\n",
      "         0.3226,  0.4928,  0.9488,  0.2040,  0.5388,  0.8397, -0.0689,  0.3136,\n",
      "         1.0450, -0.2267, -0.0896, -0.6427,  0.6443, -1.1001, -0.0096,  0.2668,\n",
      "        -0.3230, -0.6065,  0.0479, -0.1664,  0.8571,  0.2335,  0.2539,  1.2546,\n",
      "         0.5472, -0.1980, -0.7186,  0.2076, -0.2587, -0.3650,  0.0834,  0.6932,\n",
      "         0.1574,  1.0931,  0.0913, -1.3773, -0.2717,  0.7071,  0.1872, -0.3307,\n",
      "        -0.2836,  0.1030,  1.2228,  0.8374])\n",
      "Token: 3 is\n",
      "tensor([-0.5426,  0.4148,  1.0322, -0.4024,  0.4669,  0.2182, -0.0749,  0.4733,\n",
      "         0.0810, -0.2208, -0.1281, -0.1144,  0.5089,  0.1157,  0.0282, -0.3628,\n",
      "         0.4382,  0.0475,  0.2028,  0.4986, -0.1007,  0.1327,  0.1697,  0.1165,\n",
      "         0.3135,  0.2571,  0.0928, -0.5683, -0.5297, -0.0515, -0.6733,  0.9253,\n",
      "         0.2693,  0.2273,  0.6636,  0.2622,  0.1972,  0.2609,  0.1877, -0.3454,\n",
      "        -0.4263,  0.1398,  0.5634, -0.5691,  0.1240, -0.1289,  0.7248, -0.2610,\n",
      "        -0.2631, -0.4360,  0.0789, -0.8415,  0.5160,  1.3997, -0.7646, -3.1453,\n",
      "        -0.2920, -0.3125,  1.5129,  0.5243,  0.2146,  0.4245, -0.0884, -0.1780,\n",
      "         1.1876,  0.1058,  0.7657,  0.2191,  0.3582, -0.1164,  0.0933, -0.6248,\n",
      "        -0.2190,  0.2180,  0.7406, -0.4374,  0.1434,  0.1472, -1.1605, -0.0505,\n",
      "         0.1268, -0.0144, -0.9868, -0.0913, -1.2054, -0.1197,  0.0478, -0.5400,\n",
      "         0.5246, -0.7096, -0.3253, -0.1346, -0.4131,  0.3343, -0.0072,  0.3225,\n",
      "        -0.0442, -1.2969,  0.7622,  0.4635])\n",
      "Token: 4 green\n",
      "tensor([-6.7907e-01,  3.4908e-01, -2.3984e-01, -9.9652e-01,  7.3782e-01,\n",
      "        -6.5911e-04,  2.8010e-01,  1.7287e-02, -3.6063e-01,  3.6955e-02,\n",
      "        -4.0395e-01,  2.4092e-02,  2.8958e-01,  4.0497e-01,  6.9992e-01,\n",
      "         2.5269e-01,  8.0350e-01,  4.9370e-02,  1.5562e-01, -6.3286e-03,\n",
      "        -2.9414e-01,  1.4728e-01,  1.8977e-01, -5.1791e-01,  3.6986e-01,\n",
      "         7.4582e-01,  8.2689e-02, -7.2601e-01, -4.0939e-01, -9.7822e-02,\n",
      "        -1.4096e-01,  7.1121e-01,  6.1933e-01, -2.5014e-01,  4.2250e-01,\n",
      "         4.8458e-01, -5.1915e-01,  7.7125e-01,  3.6685e-01,  4.9652e-01,\n",
      "        -4.1298e-02, -1.4683e+00,  2.0038e-01,  1.8591e-01,  4.9860e-02,\n",
      "        -1.7523e-01, -3.5528e-01,  9.4153e-01, -1.1898e-01, -5.1903e-01,\n",
      "        -1.1887e-02, -3.9186e-01, -1.7479e-01,  9.3451e-01, -5.8931e-01,\n",
      "        -2.7701e+00,  3.4522e-01,  8.6533e-01,  1.0808e+00, -1.0291e-01,\n",
      "        -9.1220e-02,  5.5092e-01, -3.9473e-01,  5.3676e-01,  1.0383e+00,\n",
      "        -4.0658e-01,  2.4590e-01, -2.6797e-01, -2.6036e-01, -1.4151e-01,\n",
      "        -1.2022e-01,  1.6234e-01, -7.4320e-01, -6.4728e-01,  4.7133e-02,\n",
      "         5.1642e-01,  1.9898e-01,  2.3919e-01,  1.2550e-01,  2.2471e-01,\n",
      "         8.2613e-01,  7.8328e-02, -5.7020e-01,  2.3934e-02, -1.5410e-01,\n",
      "        -2.5739e-01,  4.1262e-01, -4.6967e-01,  8.7914e-01,  7.2629e-01,\n",
      "         5.3862e-02, -1.1575e+00, -4.7835e-01,  2.0139e-01, -1.0051e+00,\n",
      "         1.1515e-01, -9.6609e-01,  1.2960e-01,  1.8388e-01, -3.0383e-02])\n",
      "Token: 5 .\n",
      "tensor([-0.3398,  0.2094,  0.4635, -0.6479, -0.3838,  0.0380,  0.1713,  0.1598,\n",
      "         0.4662, -0.0192,  0.4148, -0.3435,  0.2687,  0.0446,  0.4213, -0.4103,\n",
      "         0.1546,  0.0222, -0.6465,  0.2526,  0.0431, -0.1945,  0.4652,  0.4565,\n",
      "         0.6859,  0.0913,  0.2188, -0.7035,  0.1679, -0.3508, -0.1263,  0.6638,\n",
      "        -0.2582,  0.0365, -0.1361,  0.4025,  0.1429,  0.3813, -0.1228, -0.4589,\n",
      "        -0.2528, -0.3043, -0.1121, -0.2618, -0.2248, -0.4455,  0.2991, -0.8561,\n",
      "        -0.1450, -0.4909,  0.0083, -0.1749,  0.2752,  1.4401, -0.2124, -2.8435,\n",
      "        -0.2796, -0.4572,  1.6386,  0.7881, -0.5526,  0.6500,  0.0864,  0.3901,\n",
      "         1.0632, -0.3538,  0.4833,  0.3460,  0.8417,  0.0987, -0.2421, -0.2705,\n",
      "         0.0453, -0.4015,  0.1139,  0.0062,  0.0367,  0.0185, -1.0213, -0.2081,\n",
      "         0.6407, -0.0688, -0.5864,  0.3348, -1.1432, -0.1148, -0.2509, -0.4591,\n",
      "        -0.0968, -0.1795, -0.0634, -0.6741, -0.0689,  0.5360, -0.8777,  0.3180,\n",
      "        -0.3924, -0.2339,  0.4730, -0.0288])\n"
     ]
    }
   ],
   "source": [
    "for token in sentence:\n",
    "    print(token)\n",
    "    print(token.embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sentence: \"The grass is green .\" - 5 Tokens]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from flair.embeddings import CharacterEmbeddings\n",
    "\n",
    "# init embedding\n",
    "embedding = CharacterEmbeddings()\n",
    "\n",
    "# create a sentence\n",
    "sentence = Sentence('The grass is green .')\n",
    "\n",
    "# embed words in sentence\n",
    "embedding.embed(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.embeddings import WordEmbeddings, CharacterEmbeddings\n",
    "\n",
    "glove_embedding = WordEmbeddings('glove')\n",
    "character_embeddings = CharacterEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.embeddings import StackedEmbeddings\n",
    "\n",
    "# Combine the GloVe word embedding and the character embedding into one stack.\n",
    "stacked_embeddings = StackedEmbeddings(embeddings=[glove_embedding, character_embeddings])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('The grass is green .')\n",
    "stacked_embeddings.embed(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "StackedEmbeddings(\n",
       "  (list_embedding_0): WordEmbeddings()\n",
       "  (list_embedding_1): CharacterEmbeddings(\n",
       "    (char_embedding): Embedding(275, 25)\n",
       "    (char_rnn): LSTM(25, 25, bidirectional=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stacked_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token: 1 The\n",
      "tensor([-3.8194e-02, -2.4487e-01,  7.2812e-01, -3.9961e-01,  8.3172e-02,\n",
      "         4.3953e-02, -3.9141e-01,  3.3440e-01, -5.7545e-01,  8.7459e-02,\n",
      "         2.8787e-01, -6.7310e-02,  3.0906e-01, -2.6384e-01, -1.3231e-01,\n",
      "        -2.0757e-01,  3.3395e-01, -3.3848e-01, -3.1743e-01, -4.8336e-01,\n",
      "         1.4640e-01, -3.7304e-01,  3.4577e-01,  5.2041e-02,  4.4946e-01,\n",
      "        -4.6971e-01,  2.6280e-02, -5.4155e-01, -1.5518e-01, -1.4107e-01,\n",
      "        -3.9722e-02,  2.8277e-01,  1.4393e-01,  2.3464e-01, -3.1021e-01,\n",
      "         8.6173e-02,  2.0397e-01,  5.2624e-01,  1.7164e-01, -8.2378e-02,\n",
      "        -7.1787e-01, -4.1531e-01,  2.0335e-01, -1.2763e-01,  4.1367e-01,\n",
      "         5.5187e-01,  5.7908e-01, -3.3477e-01, -3.6559e-01, -5.4857e-01,\n",
      "        -6.2892e-02,  2.6584e-01,  3.0205e-01,  9.9775e-01, -8.0481e-01,\n",
      "        -3.0243e+00,  1.2540e-02, -3.6942e-01,  2.2167e+00,  7.2201e-01,\n",
      "        -2.4978e-01,  9.2136e-01,  3.4514e-02,  4.6745e-01,  1.1079e+00,\n",
      "        -1.9358e-01, -7.4575e-02,  2.3353e-01, -5.2062e-02, -2.2044e-01,\n",
      "         5.7162e-02, -1.5806e-01, -3.0798e-01, -4.1625e-01,  3.7972e-01,\n",
      "         1.5006e-01, -5.3212e-01, -2.0550e-01, -1.2526e+00,  7.1624e-02,\n",
      "         7.0565e-01,  4.9744e-01, -4.2063e-01,  2.6148e-01, -1.5380e+00,\n",
      "        -3.0223e-01, -7.3438e-02, -2.8312e-01,  3.7104e-01, -2.5217e-01,\n",
      "         1.6215e-02, -1.7099e-02, -3.8984e-01,  8.7424e-01, -7.2569e-01,\n",
      "        -5.1058e-01, -5.2028e-01, -1.4590e-01,  8.2780e-01,  2.7062e-01,\n",
      "         6.3108e-03, -2.5734e-01, -1.0127e-02, -3.2274e-02,  2.8811e-03,\n",
      "        -5.4744e-03,  1.3090e-01, -5.9536e-02,  1.6519e-01, -4.3329e-02,\n",
      "        -1.3551e-01, -3.7349e-02,  1.3456e-01,  2.2238e-01, -1.4624e-01,\n",
      "         2.5576e-01, -3.7204e-02,  4.0126e-01,  2.4248e-01,  1.0761e-01,\n",
      "        -1.1448e-01, -7.3020e-02, -2.2720e-01,  2.5039e-02, -2.8237e-01,\n",
      "         1.4390e-01, -1.1083e-01, -1.3574e-01, -4.3571e-02, -8.2276e-02,\n",
      "         3.1797e-01,  1.0562e-01,  1.8005e-01, -2.1153e-01, -1.6619e-01,\n",
      "        -6.8136e-02,  2.4039e-01, -3.3454e-02, -2.0433e-01,  1.0134e-01,\n",
      "        -2.1497e-01,  1.5355e-02,  2.0674e-01,  1.4811e-01, -9.1302e-02,\n",
      "         4.0148e-02,  8.6612e-02,  4.1856e-02,  8.1967e-02, -7.4910e-02],\n",
      "       grad_fn=<CatBackward>)\n",
      "Token: 2 grass\n",
      "tensor([-0.8135,  0.9404, -0.2405, -0.1350,  0.0557,  0.3363,  0.0802, -0.1015,\n",
      "        -0.5478, -0.3537,  0.0734,  0.2587,  0.1987, -0.1433,  0.2507,  0.4281,\n",
      "         0.1950,  0.5346,  0.7424,  0.0578, -0.3178,  0.9436,  0.8145, -0.0824,\n",
      "         0.6166,  0.7284, -0.3262, -1.3641,  0.1232,  0.5373, -0.5123,  0.0246,\n",
      "         1.0822, -0.2296,  0.6039,  0.5541, -0.9610,  0.4803,  0.0022,  0.5591,\n",
      "        -0.1637, -0.8468,  0.0741, -0.6216,  0.0260, -0.5162, -0.0525, -0.1418,\n",
      "        -0.0161, -0.4972, -0.5534, -0.4037,  0.5096,  1.0276, -0.0840, -1.1179,\n",
      "         0.3226,  0.4928,  0.9488,  0.2040,  0.5388,  0.8397, -0.0689,  0.3136,\n",
      "         1.0450, -0.2267, -0.0896, -0.6427,  0.6443, -1.1001, -0.0096,  0.2668,\n",
      "        -0.3230, -0.6065,  0.0479, -0.1664,  0.8571,  0.2335,  0.2539,  1.2546,\n",
      "         0.5472, -0.1980, -0.7186,  0.2076, -0.2587, -0.3650,  0.0834,  0.6932,\n",
      "         0.1574,  1.0931,  0.0913, -1.3773, -0.2717,  0.7071,  0.1872, -0.3307,\n",
      "        -0.2836,  0.1030,  1.2228,  0.8374,  0.1004,  0.0290,  0.2366,  0.1697,\n",
      "         0.1663,  0.1168,  0.1768,  0.2029,  0.2458, -0.2917, -0.2440,  0.2163,\n",
      "         0.1219, -0.1865, -0.0176, -0.1864,  0.1176,  0.1054,  0.1579, -0.1860,\n",
      "        -0.2466, -0.1175,  0.0732, -0.2293,  0.1627,  0.0272, -0.0785,  0.0360,\n",
      "        -0.0057,  0.0218, -0.0729,  0.1934,  0.0903, -0.0927, -0.4069,  0.0892,\n",
      "        -0.0540,  0.1659,  0.0860, -0.0584, -0.2017,  0.0455, -0.0908,  0.1252,\n",
      "        -0.0151,  0.0822, -0.1524, -0.0566, -0.3361,  0.0536],\n",
      "       grad_fn=<CatBackward>)\n",
      "Token: 3 is\n",
      "tensor([-0.5426,  0.4148,  1.0322, -0.4024,  0.4669,  0.2182, -0.0749,  0.4733,\n",
      "         0.0810, -0.2208, -0.1281, -0.1144,  0.5089,  0.1157,  0.0282, -0.3628,\n",
      "         0.4382,  0.0475,  0.2028,  0.4986, -0.1007,  0.1327,  0.1697,  0.1165,\n",
      "         0.3135,  0.2571,  0.0928, -0.5683, -0.5297, -0.0515, -0.6733,  0.9253,\n",
      "         0.2693,  0.2273,  0.6636,  0.2622,  0.1972,  0.2609,  0.1877, -0.3454,\n",
      "        -0.4263,  0.1398,  0.5634, -0.5691,  0.1240, -0.1289,  0.7248, -0.2610,\n",
      "        -0.2631, -0.4360,  0.0789, -0.8415,  0.5160,  1.3997, -0.7646, -3.1453,\n",
      "        -0.2920, -0.3125,  1.5129,  0.5243,  0.2146,  0.4245, -0.0884, -0.1780,\n",
      "         1.1876,  0.1058,  0.7657,  0.2191,  0.3582, -0.1164,  0.0933, -0.6248,\n",
      "        -0.2190,  0.2180,  0.7406, -0.4374,  0.1434,  0.1472, -1.1605, -0.0505,\n",
      "         0.1268, -0.0144, -0.9868, -0.0913, -1.2054, -0.1197,  0.0478, -0.5400,\n",
      "         0.5246, -0.7096, -0.3253, -0.1346, -0.4131,  0.3343, -0.0072,  0.3225,\n",
      "        -0.0442, -1.2969,  0.7622,  0.4635, -0.0254, -0.0266,  0.2035,  0.1153,\n",
      "        -0.0407,  0.1062,  0.0477,  0.1164,  0.2812, -0.2613, -0.2390,  0.2604,\n",
      "         0.0625, -0.1660, -0.0306, -0.1705,  0.1613,  0.1041,  0.1519, -0.1656,\n",
      "        -0.2456, -0.0969,  0.1303, -0.0885,  0.1226,  0.0272, -0.0785,  0.0360,\n",
      "        -0.0057,  0.0218, -0.0729,  0.1934,  0.0903, -0.0927, -0.4069,  0.0892,\n",
      "        -0.0540,  0.1659,  0.0860, -0.0584, -0.2017,  0.0455, -0.0908,  0.1252,\n",
      "        -0.0151,  0.0822, -0.1524, -0.0566, -0.3361,  0.0536],\n",
      "       grad_fn=<CatBackward>)\n",
      "Token: 4 green\n",
      "tensor([-6.7907e-01,  3.4908e-01, -2.3984e-01, -9.9652e-01,  7.3782e-01,\n",
      "        -6.5911e-04,  2.8010e-01,  1.7287e-02, -3.6063e-01,  3.6955e-02,\n",
      "        -4.0395e-01,  2.4092e-02,  2.8958e-01,  4.0497e-01,  6.9992e-01,\n",
      "         2.5269e-01,  8.0350e-01,  4.9370e-02,  1.5562e-01, -6.3286e-03,\n",
      "        -2.9414e-01,  1.4728e-01,  1.8977e-01, -5.1791e-01,  3.6986e-01,\n",
      "         7.4582e-01,  8.2689e-02, -7.2601e-01, -4.0939e-01, -9.7822e-02,\n",
      "        -1.4096e-01,  7.1121e-01,  6.1933e-01, -2.5014e-01,  4.2250e-01,\n",
      "         4.8458e-01, -5.1915e-01,  7.7125e-01,  3.6685e-01,  4.9652e-01,\n",
      "        -4.1298e-02, -1.4683e+00,  2.0038e-01,  1.8591e-01,  4.9860e-02,\n",
      "        -1.7523e-01, -3.5528e-01,  9.4153e-01, -1.1898e-01, -5.1903e-01,\n",
      "        -1.1887e-02, -3.9186e-01, -1.7479e-01,  9.3451e-01, -5.8931e-01,\n",
      "        -2.7701e+00,  3.4522e-01,  8.6533e-01,  1.0808e+00, -1.0291e-01,\n",
      "        -9.1220e-02,  5.5092e-01, -3.9473e-01,  5.3676e-01,  1.0383e+00,\n",
      "        -4.0658e-01,  2.4590e-01, -2.6797e-01, -2.6036e-01, -1.4151e-01,\n",
      "        -1.2022e-01,  1.6234e-01, -7.4320e-01, -6.4728e-01,  4.7133e-02,\n",
      "         5.1642e-01,  1.9898e-01,  2.3919e-01,  1.2550e-01,  2.2471e-01,\n",
      "         8.2613e-01,  7.8328e-02, -5.7020e-01,  2.3934e-02, -1.5410e-01,\n",
      "        -2.5739e-01,  4.1262e-01, -4.6967e-01,  8.7914e-01,  7.2629e-01,\n",
      "         5.3862e-02, -1.1575e+00, -4.7835e-01,  2.0139e-01, -1.0051e+00,\n",
      "         1.1515e-01, -9.6609e-01,  1.2960e-01,  1.8388e-01, -3.0383e-02,\n",
      "         2.3410e-01, -1.4150e-01, -1.4317e-01, -7.9950e-02,  2.2265e-01,\n",
      "        -3.1271e-02,  2.4928e-01,  9.5457e-02,  7.0562e-03,  8.6135e-02,\n",
      "         1.3798e-01, -6.3350e-02,  9.9218e-02, -4.5819e-03,  1.9424e-01,\n",
      "         3.0682e-01,  7.8153e-03,  1.2644e-01,  7.8239e-02,  9.1541e-02,\n",
      "        -3.2165e-02, -9.5144e-02, -1.1466e-01, -3.8280e-02, -7.9813e-02,\n",
      "         7.5818e-03,  1.6530e-01, -7.5781e-02, -5.4557e-02, -7.6738e-02,\n",
      "        -4.6856e-02, -1.0195e-01,  9.9022e-02,  2.4027e-01,  1.0468e-02,\n",
      "         1.9845e-01, -2.1230e-02,  7.1300e-02,  1.7585e-02, -9.3911e-03,\n",
      "        -9.7738e-02,  1.1224e-01,  3.7499e-02, -2.0135e-01,  8.5252e-02,\n",
      "         6.1836e-02, -3.2621e-02,  1.1995e-02, -2.0415e-01, -2.8720e-02],\n",
      "       grad_fn=<CatBackward>)\n",
      "Token: 5 .\n",
      "tensor([-3.3979e-01,  2.0941e-01,  4.6348e-01, -6.4792e-01, -3.8377e-01,\n",
      "         3.8034e-02,  1.7127e-01,  1.5978e-01,  4.6619e-01, -1.9169e-02,\n",
      "         4.1479e-01, -3.4349e-01,  2.6872e-01,  4.4640e-02,  4.2131e-01,\n",
      "        -4.1032e-01,  1.5459e-01,  2.2239e-02, -6.4653e-01,  2.5256e-01,\n",
      "         4.3136e-02, -1.9445e-01,  4.6516e-01,  4.5651e-01,  6.8588e-01,\n",
      "         9.1295e-02,  2.1875e-01, -7.0351e-01,  1.6785e-01, -3.5079e-01,\n",
      "        -1.2634e-01,  6.6384e-01, -2.5820e-01,  3.6542e-02, -1.3605e-01,\n",
      "         4.0253e-01,  1.4289e-01,  3.8132e-01, -1.2283e-01, -4.5886e-01,\n",
      "        -2.5282e-01, -3.0432e-01, -1.1215e-01, -2.6182e-01, -2.2482e-01,\n",
      "        -4.4554e-01,  2.9910e-01, -8.5612e-01, -1.4503e-01, -4.9086e-01,\n",
      "         8.2973e-03, -1.7491e-01,  2.7524e-01,  1.4401e+00, -2.1239e-01,\n",
      "        -2.8435e+00, -2.7958e-01, -4.5722e-01,  1.6386e+00,  7.8808e-01,\n",
      "        -5.5262e-01,  6.5000e-01,  8.6426e-02,  3.9012e-01,  1.0632e+00,\n",
      "        -3.5379e-01,  4.8328e-01,  3.4600e-01,  8.4174e-01,  9.8707e-02,\n",
      "        -2.4213e-01, -2.7053e-01,  4.5287e-02, -4.0147e-01,  1.1395e-01,\n",
      "         6.2226e-03,  3.6673e-02,  1.8518e-02, -1.0213e+00, -2.0806e-01,\n",
      "         6.4072e-01, -6.8763e-02, -5.8635e-01,  3.3476e-01, -1.1432e+00,\n",
      "        -1.1480e-01, -2.5091e-01, -4.5907e-01, -9.6819e-02, -1.7946e-01,\n",
      "        -6.3351e-02, -6.7412e-01, -6.8895e-02,  5.3604e-01, -8.7773e-01,\n",
      "         3.1802e-01, -3.9242e-01, -2.3394e-01,  4.7298e-01, -2.8803e-02,\n",
      "         8.2464e-02, -1.7575e-01, -1.4336e-01,  3.9867e-03, -1.4155e-01,\n",
      "        -7.6877e-03,  1.0880e-03, -6.3159e-02, -7.6448e-02,  8.3365e-02,\n",
      "         3.6257e-03,  7.6893e-03,  1.4932e-02, -3.5098e-03,  2.7587e-02,\n",
      "        -3.3187e-02, -6.8181e-03,  1.6592e-01,  2.3646e-02,  1.7029e-01,\n",
      "        -4.5547e-02, -6.0603e-02,  1.0320e-01, -8.0149e-02, -1.5537e-01,\n",
      "         9.6964e-02, -1.3416e-01, -2.1076e-01, -1.3461e-01,  8.6052e-02,\n",
      "        -1.5016e-01,  1.9833e-01, -1.9856e-02, -1.9699e-01,  3.4966e-02,\n",
      "        -4.7196e-02,  8.6259e-02,  1.0409e-01, -7.2638e-02,  2.0218e-01,\n",
      "        -5.5694e-02,  7.0337e-02,  1.3896e-01,  1.0324e-01, -8.2287e-02,\n",
      "        -7.9263e-02, -6.4011e-02,  2.1714e-03,  5.0975e-02, -1.6845e-02],\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       grad_fn=<CatBackward>)\n"
     ]
    }
   ],
   "source": [
    "for token in sentence:\n",
    "    print(token)\n",
    "    print(token.embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.data import Sentence\n",
    "from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentPoolEmbeddings\n",
    "\n",
    "# Token-level embeddings whose outputs are pooled into one document vector.\n",
    "glove_embedding = WordEmbeddings('glove')\n",
    "# FlairEmbeddings replaces CharLMEmbeddings, which flair reports as deprecated since 0.4.\n",
    "charlm_embedding_forward = FlairEmbeddings('news-forward')\n",
    "charlm_embedding_backward = FlairEmbeddings('news-backward')\n",
    "\n",
    "# No mode is passed here, so the pooling operation is whatever the library defaults to.\n",
    "document_embeddings = DocumentPoolEmbeddings([glove_embedding,\n",
    "                                              charlm_embedding_forward,\n",
    "                                              charlm_embedding_backward])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('The grass is green . And the sky is blue .')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "document_embeddings.embed(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.3197,  0.2621,  0.4037,  ..., -0.0008, -0.0051, -0.0109]])\n"
     ]
    }
   ],
   "source": [
    "print(sentence.get_embedding())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the pooled document embedding over the same three token embeddings,\n",
    "# this time selecting the 'min' pooling mode explicitly.\n",
    "document_embeddings = DocumentPoolEmbeddings(\n",
    "    [glove_embedding, charlm_embedding_backward, charlm_embedding_forward],\n",
    "    mode='min',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.embeddings import DocumentLSTMEmbeddings, WordEmbeddings\n",
    "\n",
    "# Swap the document embedding for DocumentLSTMEmbeddings, fed by GloVe word embeddings only.\n",
    "glove_embedding = WordEmbeddings('glove')\n",
    "document_embeddings = DocumentLSTMEmbeddings(\n",
    "    [glove_embedding],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = Sentence('The grass is green . And the sky is blue .')\n",
    "document_embeddings.embed(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0000, -0.2567, -0.3857,  0.0000,  0.0000,  0.4679, -0.0000, -0.0000,\n",
      "         -0.0000,  0.0413,  0.3378, -0.0000, -0.0000, -0.0000,  0.6527, -0.6511,\n",
      "          1.0144, -0.1377,  0.5243, -0.5654,  0.0000, -0.0236,  0.1107,  0.0000,\n",
      "         -0.7132, -0.5130, -0.3489, -0.5734,  0.7072,  0.1158, -0.3548,  0.0000,\n",
      "          0.0000, -0.1011,  0.0743,  0.5346,  0.2456,  0.3685,  0.0000,  0.1319,\n",
      "         -0.6749, -0.0000,  0.0000, -0.3798,  0.4302,  0.0000,  0.1881,  0.4432,\n",
      "         -0.0000,  0.6083, -0.2418,  0.5634, -0.7348,  0.7113, -0.3781, -0.4040,\n",
      "          0.7722, -0.6238,  0.8772,  0.0000,  0.5456,  0.4980,  0.0000,  0.1653,\n",
      "         -0.0000,  0.0553, -0.8303,  0.5382, -0.0000,  0.0000,  0.1737, -0.2544,\n",
      "         -1.0751,  0.0816,  0.0000, -0.6108,  0.0000,  0.7551, -0.0000, -0.0000,\n",
      "         -0.0000,  0.0000, -0.2756,  0.0173,  0.0000, -0.0000,  0.0904,  0.0000,\n",
      "          0.3185, -0.0000,  0.0000, -0.0000, -0.0000, -0.0000,  0.1771, -0.4003,\n",
      "          0.0000,  0.0000, -0.6380, -0.3645, -0.0000,  0.0000,  0.0000,  0.0000,\n",
      "          0.5596,  0.0000, -0.0000, -0.1360,  0.2858, -0.0000, -0.6948, -0.0000,\n",
      "         -1.0255, -0.1839, -0.5161, -0.0000, -0.0000, -0.0791, -0.3432,  0.5404,\n",
      "          0.3125,  0.0000,  0.0000, -0.0419,  0.0000, -0.0000,  0.6848,  0.0000]],\n",
      "       grad_fn=<CatBackward>)\n"
     ]
    }
   ],
   "source": [
    "print(sentence.get_embedding())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.data import TaggedCorpus\n",
    "# NLPTask must be imported alongside the fetcher; without it this cell raises\n",
    "# NameError: name 'NLPTask' is not defined.\n",
    "from flair.data_fetcher import NLPTaskDataFetcher, NLPTask\n",
    "\n",
    "# Load the corpus identified by NLPTask.UD_ENGLISH through the data fetcher.\n",
    "corpus: TaggedCorpus = NLPTaskDataFetcher.load_corpus(NLPTask.UD_ENGLISH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flair.data import Sentence\n",
    "from flair.models import TextClassifier\n",
    "\n",
    "# NOTE(review): the previously saved run of this cell failed inside torch.load with\n",
    "# 'storage has wrong size'. That most likely points to a truncated or corrupt cached\n",
    "# model file rather than a problem in this code -- presumably deleting the cached\n",
    "# 'en-sentiment' model and re-running fixes it. TODO confirm.\n",
    "classifier = TextClassifier.load('en-sentiment')\n",
    "\n",
    "sentence = Sentence('Flair is pretty neat!')\n",
    "classifier.predict(sentence)\n",
    "# print sentence with predicted labels\n",
    "print('Sentence above is: ', sentence.labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
